#!/usr/bin/env python3

import argparse
import os
import sys
import queue
import hashlib
import threading
from pathlib import Path

# find direct_io in the same directory as this program
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import direct_io # for pread(), allocate_aligned_buffers(), round_up()

def parse_and_validate_args():
    """Parse and validate the command-line arguments.

    Directories given on the command line are always processed recursively;
    there is no option to disable that.

    Returns:
        argparse.Namespace with ``paths`` (list of str), ``num_threads``
        (int) and ``buffer_size`` (int, in bytes).

    Exits via ``sys.exit`` with an error message when:
        * the thread count is smaller than 1,
        * the buffer size is not a positive multiple of 2MiB that evenly
          divides 1GiB,
        * any of the given paths does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Compute per-GB MD5 checksums of each file (recursive if directory)."
    )

    parser.add_argument("paths", nargs="+", help="Files or directories to process")
    parser.add_argument("-n", "--num-threads", type=int, default=16,
                        help="Number of threads (default: 16)")
    parser.add_argument("-b", "--buffer-size", type=int, default=256 * 1024 * 1024,
                        help="Buffer size in bytes (default: 256MiB, must be a multiple of 2MiB and evenly divide 1GiB)")

    args = parser.parse_args()

    # Validate thread count: with fewer than one worker nothing would ever
    # be checksummed and the program would exit silently.
    if args.num_threads < 1:
        sys.exit("Error: number of threads must be at least 1")

    # Validate buffer size. The sign check must come first: 0 would make the
    # "divides 1GiB" test below raise ZeroDivisionError, and negative
    # multiples of 2MiB would otherwise satisfy both modulo checks.
    if args.buffer_size <= 0:
        sys.exit("Error: buffer size must be positive")
    if args.buffer_size % (2 * 1024 * 1024) != 0:
        sys.exit("Error: buffer size must be a multiple of 2MiB")
    if (1 << 30) % args.buffer_size != 0:
        sys.exit("Error: buffer size must evenly divide 1GiB")

    # Validate paths
    for p in args.paths:
        if not os.path.exists(p):
            sys.exit(f"{os.path.basename(sys.argv[0])}: cannot stat '{p}': No such file or directory")

    return args

def enqueue_file_chunks(display_name, actual_path, chunk_size, workpile):
    """Split one file into fixed-size chunks and queue a work item for each.

    The file size and the filesystem block size (``f_bsize`` from
    ``os.statvfs``) are looked up through ``actual_path``. Every queued item
    is the tuple ``(display_name, offset, length, block_size)``; note that it
    carries ``display_name``, not ``actual_path``. The last chunk may be
    shorter than ``chunk_size``. A file of size 0 queues nothing.
    """
    total_bytes = os.path.getsize(actual_path)
    fs_block_size = os.statvfs(actual_path).f_bsize
    for start in range(0, total_bytes, chunk_size):
        # Clamp the chunk end to the file size so the tail chunk is short.
        end = min(start + chunk_size, total_bytes)
        work_item = (display_name, start, end - start, fs_block_size)
        workpile.put(work_item)

def checksum_worker(thread_id, buffer, workpile, buffer_size, print_lock):
    """Drain ``workpile`` and print one MD5 line per file chunk.

    Each work item is a ``(filepath, offset, size, fs_block_size)`` tuple.
    The chunk is read through ``direct_io.pread`` in pieces of at most
    ``buffer_size`` bytes, hashed, and reported on stdout as
    ``filepath<TAB>offset<TAB>size<TAB>md5hex``.

    The worker returns as soon as the queue is found empty, so the queue has
    to be fully populated before workers are started. An exception raised
    while processing an item propagates out of this function.

    Args:
        thread_id: Passed through unchanged to ``direct_io.pread``.
        buffer: Read buffer handed to ``direct_io.pread``; the first
            ``this_read`` bytes of it are hashed after every read.
        workpile: ``queue.Queue`` of work items.
        buffer_size: Maximum number of bytes requested per ``pread`` call.
        print_lock: Lock serialising the result lines across workers.

    Raises:
        OSError: If ``direct_io.pread`` reports a byte count different from
            the one requested (or if the file cannot be opened).
    """
    while True:
        try:
            filepath, offset, size, fs_block_size = workpile.get_nowait()
        except queue.Empty:
            return

        md5 = hashlib.md5()

        with open(filepath, 'rb') as f:
            fd = f.fileno()
            remaining = size
            local_offset = offset
            while remaining > 0:
                this_read = min(remaining, buffer_size)
                bytes_read = direct_io.pread(fd, buffer, this_read, local_offset, fs_block_size, thread_id)
                # Explicit check rather than `assert`: assertions are removed
                # under `python -O`, and a short read would then silently
                # hash stale buffer contents.
                if bytes_read != this_read:
                    raise OSError(
                        f"pread returned {bytes_read}, expected {this_read} "
                        f"({filepath} at offset {local_offset})"
                    )
                md5.update(buffer[:this_read])
                local_offset += this_read
                remaining -= this_read

        # Finish the digest before taking the lock so the critical section
        # only covers the actual write to stdout.
        digest = md5.hexdigest()
        with print_lock:
            print(f"{filepath}\t{offset}\t{size}\t{digest}")
        workpile.task_done()

def main():
    """Checksum every file named on the command line in 1 GiB chunks.

    Steps:
      1. Parse and validate the arguments.
      2. Walk the given paths (directories recursively) and queue one work
         item per 1 GiB chunk of every regular file found.
      3. Allocate one aligned read buffer per thread via ``direct_io``.
      4. Start ``num_threads`` workers running ``checksum_worker`` and wait
         for all of them to finish.

    Inside directories, symlinks that do not lead to a regular file
    (dangling links, links to directories, ...) are skipped with a warning
    on stderr instead of being queued.
    """
    args = parse_and_validate_args()
    workpile = queue.Queue()
    CHUNK_SIZE = 1 << 30  # 1 GiB

    # The queue is filled completely before any worker starts; workers stop
    # as soon as they see it empty.
    for original_arg in args.paths:
        path = Path(original_arg)
        if path.is_dir():
            for file in path.rglob('*'):
                display_name = os.path.join(original_arg, os.path.relpath(file, path))
                if file.is_file():
                    # is_file() follows symlinks, so links to regular files
                    # are included here as well.
                    enqueue_file_chunks(display_name, file, CHUNK_SIZE, workpile)
                elif file.is_symlink():
                    # A dangling link would make os.path.getsize() raise and
                    # abort the whole run; a link to a directory would be
                    # queued and then fail when a worker tries to open it.
                    print(
                        f"{os.path.basename(sys.argv[0])}: skipping '{display_name}': "
                        "symlink does not point to a regular file",
                        file=sys.stderr,
                    )
        elif path.is_file() or path.is_symlink():
            enqueue_file_chunks(original_arg, path, CHUNK_SIZE, workpile)

    # Prepare aligned buffers for each thread
    alignment = 2 * 1024 * 1024  # 2 MiB
    assert alignment >= 512, "alignment smaller than 512 bytes"
    # raw_buf is not used directly below. NOTE(review): presumably it is the
    # backing allocation for thread_buffers and must stay referenced while
    # the workers run -- confirm against direct_io.allocate_aligned_buffers.
    thread_buffers, raw_buf = direct_io.allocate_aligned_buffers(
        args.buffer_size,
        alignment,
        args.num_threads
    )

    # Launch worker threads
    print_lock = threading.Lock()
    threads = []
    for i in range(args.num_threads):
        t = threading.Thread(target=checksum_worker, args=(i, thread_buffers[i], workpile, args.buffer_size, print_lock))
        t.start()
        threads.append(t)

    # Completion is tracked by joining the threads, not the queue.
    for t in threads:
        t.join()

# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
